import torch

import torch.nn as nn


def compute_ideal_plasticity_loss(F, stress):
    """Compute soft-constraint losses for an ideal elastic-plastic response.

    Two coefficient-of-variation (std / mean) penalties are built from the
    von Mises equivalents of the stress and of the Green-Lagrange strain
    ``0.5 * (F^T F - I)``:

    * ``'elastic_linearity'`` -- spread of the per-particle ratio between
      the mean-normalised stress and strain magnitudes.
    * ``'constant_yield'`` -- spread of the mean-normalised stress among
      the particles whose von Mises stress exceeds half of the batch mean.

    Args:
        F: Deformation gradients, indexed as ``[n_particles, 3, 3]``.
        stress: Stress tensors, indexed as ``[n_particles, 3, 3]``.

    Returns:
        dict with the keys ``'elastic_linearity'`` and ``'constant_yield'``,
        each mapping to a scalar tensor. A term is zero when fewer than two
        particles contribute to it.
    """

    def coefficient_of_variation(values):
        # Tensor.std() applies Bessel's correction by default and is NaN for
        # fewer than two samples. A NaN term would poison the total loss, so
        # report "no spread" instead (this also covers an empty selection).
        if values.numel() < 2:
            return values.new_zeros(())
        return values.std() / (values.mean() + 1e-6)

    I = torch.eye(3, device=F.device).unsqueeze(0)  # [1, 3, 3]
    strain = 0.5 * (torch.bmm(F.transpose(1, 2), F) - I)  # [n_particles, 3, 3]

    def elastic_linearity_constraint(stress, strain):
        stress_vm = compute_von_mises(stress)  # [n_particles]
        strain_vm = compute_von_mises(strain)  # [n_particles]

        # Normalise by the batch mean so the ratio is scale-free.
        stress_vm_norm = stress_vm / (stress_vm.mean() + 1e-6)
        strain_vm_norm = strain_vm / (strain_vm.mean() + 1e-6)

        elastic_modulus = stress_vm_norm / (strain_vm_norm + 1e-6)
        return coefficient_of_variation(elastic_modulus)

    def constant_yield_stress_constraint(stress):
        stress_vm = compute_von_mises(stress)  # [n_particles]
        stress_vm_norm = stress_vm / (stress_vm.mean() + 1e-6)

        # Particles above half of the mean stress are treated as yielded.
        yielded_mask = stress_vm > stress_vm.mean() * 0.5
        return coefficient_of_variation(stress_vm_norm[yielded_mask])

    losses = {
        'elastic_linearity': elastic_linearity_constraint(stress, strain),
        'constant_yield': constant_yield_stress_constraint(stress)
    }

    return losses

def compute_von_mises(tensor):
    """Return the von Mises equivalent magnitude of each 3x3 tensor.

    The result is ``sqrt(1.5 * sum(dev * dev))`` where ``dev`` is the input
    with one third of its trace removed from the diagonal.

    Args:
        tensor: Tensor whose last two dimensions are 3x3. Any leading
            dimensions are treated as batch dimensions.

    Returns:
        Tensor with the last two dimensions reduced away, e.g. shape
        ``[n_particles]`` for an ``[n_particles, 3, 3]`` input.
    """
    trace = tensor[..., 0, 0] + tensor[..., 1, 1] + tensor[..., 2, 2]
    mean = trace.unsqueeze(-1).unsqueeze(-1) / 3

    deviatoric = tensor - mean * torch.eye(3, device=tensor.device)
    squared = 1.5 * torch.sum(deviatoric * deviatoric, dim=(-2, -1))

    # sqrt has an infinite derivative at 0, which turns into NaN gradients
    # when the deviatoric part vanishes (undeformed or purely hydrostatic
    # states). Evaluate sqrt on a dummy value for those entries and select
    # an exact zero instead: forward values are unchanged, the gradient
    # there becomes 0. NaN inputs still propagate because NaN != 0.
    is_zero = squared == 0
    safe_squared = torch.where(is_zero, torch.ones_like(squared), squared)
    von_mises = torch.where(
        is_zero, torch.zeros_like(squared), torch.sqrt(safe_squared)
    )
    return von_mises
    


    
def volume_preserve(F):
    """Penalise deviation of ``det(F)`` from one.

    Args:
        F: Square matrices (optionally batched) accepted by ``torch.det``.

    Returns:
        Scalar tensor holding the mean of ``(det(F) - 1)^2``.
    """
    volume_change = torch.det(F) - 1.0
    squared_error = volume_change ** 2
    return torch.mean(squared_error)

def stress_symmetry(stress):
    """Penalise the antisymmetric part of the stress tensors.

    Args:
        stress: Tensor whose last two dimensions form square matrices.

    Returns:
        Scalar tensor: the mean squared norm of ``stress - stress^T`` taken
        over the last two dimensions, scaled by ``1e-5``.
    """
    antisymmetric = stress - stress.transpose(-1, -2)
    magnitude = torch.norm(antisymmetric, dim=(-2, -1))
    mean_squared = torch.mean(magnitude ** 2)
    return mean_squared * 1e-5

def incompressible_plasticity(F_plastic):
    """Penalise plastic deformation whose determinant differs from one.

    Args:
        F_plastic: Square matrices (optionally batched) accepted by
            ``torch.det``.

    Returns:
        Scalar tensor holding the mean of ``(det(F_plastic) - 1)^2``.
    """
    residual = torch.det(F_plastic) - 1.0
    return torch.mean(residual ** 2)

def linear_elasticity(F, stress):
    """Penalise a non-constant stress/strain ratio across particles.

    The Green-Lagrange strain ``0.5 * (F^T F - I)`` and the stress are each
    reduced to a von Mises magnitude and normalised by the batch mean. The
    loss is the coefficient of variation (std / mean) of their per-particle
    ratio.

    Args:
        F: Deformation gradients, indexed as ``[n_particles, 3, 3]``.
        stress: Stress tensors, indexed as ``[n_particles, 3, 3]``.

    Returns:
        Scalar tensor. Zero when fewer than two particles are given.
    """
    I = torch.eye(3, device=F.device).unsqueeze(0)  # [1, 3, 3]
    strain = 0.5 * (torch.bmm(F.transpose(1, 2), F) - I)  # [n_particles, 3, 3]

    stress_vm = compute_von_mises(stress)  # [n_particles]
    strain_vm = compute_von_mises(strain)  # [n_particles]

    # Normalise by the batch mean so the ratio is scale-free.
    stress_vm_norm = stress_vm / (stress_vm.mean() + 1e-6)
    strain_vm_norm = strain_vm / (strain_vm.mean() + 1e-6)

    elastic_modulus = stress_vm_norm / (strain_vm_norm + 1e-6)

    # Tensor.std() applies Bessel's correction by default and is NaN for
    # fewer than two samples; a NaN loss would poison training, so report
    # "no spread" instead.
    if elastic_modulus.numel() < 2:
        return elastic_modulus.new_zeros(())

    modulus_mean = elastic_modulus.mean()
    modulus_std = elastic_modulus.std()
    modulus_cv = modulus_std / (modulus_mean + 1e-6)

    return modulus_cv

    

class GatingNetwork(nn.Module):
    """Small MLP mapping an (F, stress) pair to per-constraint gate values.

    Both inputs are flattened per sample and concatenated, then passed
    through ``Linear -> ReLU -> Linear -> Sigmoid``, so every output lies
    in the closed interval [0, 1].

    Args:
        input_dim: Length of the concatenated, flattened input. The default
            of 18 matches two 3x3 tensors per sample.
        num_constraints: Number of gate values produced per sample.
    """

    def __init__(self, input_dim=18, num_constraints=3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, num_constraints),
            nn.Sigmoid()
        )

    def forward(self, F, stress):
        """Return gate weights of shape ``(batch_size, num_constraints)``.

        Args:
            F: Tensor whose first dimension is the batch dimension.
            stress: Tensor with the same batch size as ``F``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed tensors; it copies only when it has to.
        F_flatten = F.reshape(F.size(0), -1)
        stress_flatten = stress.reshape(stress.size(0), -1)
        x = torch.cat([F_flatten, stress_flatten], dim=1)
        weights = self.mlp(x)       # (batch_size, num_constraints)
        return weights
    
def cosine_similarity(grad1, grad2):
    """Cosine similarity between two gradients given as tensor sequences.

    Each argument is treated as one long vector formed by all of its
    tensors. Tensors are paired positionally for the dot product.

    Args:
        grad1: Sequence of tensors.
        grad2: Sequence of tensors, paired element-wise with ``grad1``.

    Returns:
        ``dot / (|grad1| * |grad2| + 1e-8)``; the epsilon keeps the result
        finite when either gradient is all zeros.
    """
    inner = 0
    for a, b in zip(grad1, grad2):
        inner = inner + torch.sum(a * b)

    def l2_norm(grads):
        # Square root of the summed squares over every tensor.
        total = 0
        for g in grads:
            total = total + torch.sum(g ** 2)
        return total ** 0.5

    magnitude1 = l2_norm(grad1)
    magnitude2 = l2_norm(grad2)
    return inner / (magnitude1 * magnitude2 + 1e-8)
    
def project_grad(g_aux, g_main):
    """Remove from ``g_aux`` its component along ``g_main``.

    Both arguments are sequences of tensors paired positionally and treated
    as one long vector each.

    Args:
        g_aux: Sequence of tensors to be projected.
        g_main: Sequence of tensors defining the direction to remove.

    Returns:
        List of tensors ``g_aux[i] - c * g_main[i]`` with
        ``c = <g_aux, g_main> / (|g_main|^2 + 1e-8)``.
    """
    inner = 0
    for aux, main in zip(g_aux, g_main):
        inner = inner + torch.sum(aux * main)

    main_sq_norm = 0
    for main in g_main:
        main_sq_norm = main_sq_norm + torch.sum(main ** 2)

    # The epsilon guards against an all-zero main gradient.
    scale = inner / (main_sq_norm + 1e-8)

    return [aux - scale * main for aux, main in zip(g_aux, g_main)]
    
def compute_weight(cos_sim, threshold=0.3, min_weight=0.0):
    """Map a cosine similarity to a weight via a thresholded linear ramp.

    Below ``threshold`` the weight is ``min_weight``. From ``threshold`` up
    to 1 it rises linearly from 0 to 1 and is never allowed to fall below
    ``min_weight``, so the weight is non-decreasing in ``cos_sim``.

    Args:
        cos_sim: Scalar similarity (number or zero-dimensional tensor).
        threshold: Similarity at which the ramp starts.
        min_weight: Floor applied to the returned weight.

    Returns:
        The weight, as a number or tensor depending on the inputs.
    """
    if cos_sim < threshold:
        return min_weight

    ramp_width = 1 - threshold
    if ramp_width <= 0:
        # Zero-width ramp: dividing would raise or yield NaN. Anything that
        # reaches the threshold counts as fully aligned.
        return max(1.0, min_weight)

    ramp = (cos_sim - threshold) / ramp_width
    # Without the floor, the weight would drop from min_weight to about 0
    # just above the threshold. `ramp` goes first so that ties (including
    # the default min_weight of 0.0) return the ramp value unchanged.
    return max(ramp, min_weight)